{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<small><i>This notebook was put together by [Jake Vanderplas](http://www.vanderplas.com). Source and license info is on [GitHub](https://github.com/jakevdp/sklearn_tutorial/).</i></small>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Supervised Learning In-Depth: Support Vector Machines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Previously we introduced supervised machine learning.\n",
    "There are many supervised learning algorithms available; here we'll look briefly at one of the most powerful and interesting methods: **Support Vector Machines (SVMs)**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "\n",
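    "# The 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6; use whichever exists\n",
    "plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')"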
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Motivating Support Vector Machines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Support Vector Machines (SVMs) are a powerful family of supervised learning algorithms used for **classification** or **regression**. SVMs are **discriminative** classifiers: that is, they draw a boundary between clusters of data.\n",
    "\n",
    "Let's show a quick example of support vector classification. First we need to create a dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_blobs  # samples_generator was removed from scikit-learn\n",
    "X, y = make_blobs(n_samples=50, centers=2,\n",
    "                  random_state=0, cluster_std=0.60)\n",
    "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A discriminative classifier attempts to draw a line between the two sets of data. Immediately we see a problem: the task is ill-posed! For example, several different lines perfectly discriminate between the classes in this example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xfit = np.linspace(-1, 3.5)\n",
    "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n",
    "\n",
    "for m, b in [(1, 0.65), (0.5, 1.6), (-0.2, 2.9)]:\n",
    "    plt.plot(xfit, m * xfit + b, '-k')\n",
    "\n",
    "plt.xlim(-1, 3.5);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These are three *very* different separators which perfectly discriminate between these samples. Depending on which you choose, a new data point may be classified entirely differently!\n",
    "\n",
    "How can we improve on this?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Support Vector Machines: Maximizing the *Margin*\n",
    "\n",
    "Support vector machines are one way to address this.\n",
    "What support vector machines do is not only draw a line, but also consider a *region* of some given width around the line. Here's an example of what it might look like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xfit = np.linspace(-1, 3.5)\n",
    "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n",
    "\n",
    "for m, b, d in [(1, 0.65, 0.33), (0.5, 1.6, 0.55), (-0.2, 2.9, 0.2)]:\n",
    "    yfit = m * xfit + b\n",
    "    plt.plot(xfit, yfit, '-k')\n",
    "    plt.fill_between(xfit, yfit - d, yfit + d, edgecolor='none', color='#AAAAAA', alpha=0.4)\n",
    "\n",
    "plt.xlim(-1, 3.5);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice here that if we want to maximize this width, the middle fit is clearly the best.\n",
    "This is the intuition behind **support vector machines**, which optimize a linear discriminant model together with a **margin**: the perpendicular distance from the separating line to the nearest points of each class."
   ]
  },
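  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The width argument can be made concrete by measuring, for each candidate line, the perpendicular distance to its closest point. The cell below is a minimal sketch rather than part of the original analysis: `min_perpendicular_distance`, `X_demo`, and `y_demo` are hypothetical names introduced here, and the data are regenerated with the same `make_blobs` settings used earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_blobs\n",
    "\n",
    "# Same settings as the dataset above; the names are illustrative only\n",
    "X_demo, y_demo = make_blobs(n_samples=50, centers=2,\n",
    "                            random_state=0, cluster_std=0.60)\n",
    "\n",
    "def min_perpendicular_distance(m, b, points):\n",
    "    \"\"\"Smallest perpendicular distance from any point to the line y = m*x + b.\"\"\"\n",
    "    distances = np.abs(m * points[:, 0] - points[:, 1] + b) / np.sqrt(m ** 2 + 1)\n",
    "    return distances.min()\n",
    "\n",
    "# The three hand-picked lines from the plots above\n",
    "for m, b in [(1, 0.65), (0.5, 1.6), (-0.2, 2.9)]:\n",
    "    d = min_perpendicular_distance(m, b, X_demo)\n",
    "    print(f'm={m:5.2f}, b={b:4.2f}: closest point is {d:.3f} away')"
   ]
  },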
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Fitting a Support Vector Machine\n",
    "\n",
    "Now we'll fit a Support Vector Machine Classifier to these points. While the mathematical details of the underlying optimization problem are interesting, we'll let you read about those elsewhere. Instead, we'll just treat the scikit-learn algorithm as a black box which accomplishes the above task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import SVC  # \"Support Vector Classifier\"\n",
    "clf = SVC(kernel='linear')\n",
    "clf.fit(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To better visualize what's happening here, let's create a quick convenience function that will plot SVM decision boundaries for us:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_svc_decision_function(clf, ax=None):\n",
    "    \"\"\"Plot the decision function for a 2D SVC\"\"\"\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "    xlim, ylim = ax.get_xlim(), ax.get_ylim()  # use the axes that was passed in\n",
    "    x = np.linspace(xlim[0], xlim[1], 30)\n",
    "    y = np.linspace(ylim[0], ylim[1], 30)\n",
    "    Y, X = np.meshgrid(y, x)\n",
    "    P = np.zeros_like(X)\n",
    "    for i, xi in enumerate(x):\n",
    "        for j, yj in enumerate(y):\n",
    "            # decision_function returns a length-1 array; take its single value\n",
    "            P[i, j] = clf.decision_function([[xi, yj]])[0]\n",
    "    # plot the margins\n",
    "    ax.contour(X, Y, P, colors='k',\n",
    "               levels=[-1, 0, 1], alpha=0.5,\n",
    "               linestyles=['--', '-', '--'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n",
    "plot_svc_decision_function(clf);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that the dashed lines touch a couple of the points: these points are the pivotal pieces of this fit, and are known as the *support vectors* (giving the algorithm its name).\n",
    "In scikit-learn, these are stored in the ``support_vectors_`` attribute of the classifier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n",
    "plot_svc_decision_function(clf)\n",
    "plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],\n",
    "            s=200, facecolors='none');"
   ]
  },
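  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A fitted linear `SVC` also exposes the indices of the support vectors (`support_`), the number of support vectors per class (`n_support_`), and the parameters of the separating line (`coef_`, `intercept_`). As a rough sketch, assuming the same blob data as above, the distance between the two dashed lines can be read off as $2 / \\|w\\|$, where $w$ is `coef_[0]`. The names `X_demo`, `y_demo`, `model`, and `margin_width` are introduced here for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_blobs\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "# Regenerate the blob data so this cell runs on its own\n",
    "X_demo, y_demo = make_blobs(n_samples=50, centers=2,\n",
    "                            random_state=0, cluster_std=0.60)\n",
    "model = SVC(kernel='linear').fit(X_demo, y_demo)\n",
    "\n",
    "w = model.coef_[0]                     # normal vector of the separating line\n",
    "margin_width = 2 / np.linalg.norm(w)   # distance between the levels -1 and +1\n",
    "\n",
    "print('support vector indices:   ', model.support_)\n",
    "print('support vectors per class:', model.n_support_)\n",
    "print('intercept:                ', model.intercept_[0])\n",
    "print('margin width:             ', margin_width)"
   ]
  },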
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's use the ``interact`` function from ``ipywidgets`` to explore how the distribution of points affects the support vectors and the discriminative fit.\n",
    "(This requires the ``ipywidgets`` package and a live kernel; it will not work in a static view.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipywidgets import interact\n",
    "\n",
    "def plot_svm(N=10):\n",
    "    X, y = make_blobs(n_samples=200, centers=2,\n",
    "                      random_state=0, cluster_std=0.60)\n",
    "    X = X[:N]\n",
    "    y = y[:N]\n",
    "    clf = SVC(kernel='linear')\n",
    "    clf.fit(X, y)\n",
    "    plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n",
    "    plt.xlim(-1, 4)\n",
    "    plt.ylim(-1, 6)\n",
    "    plot_svc_decision_function(clf, plt.gca())\n",
    "    plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],\n",
    "                s=200, facecolors='none')\n",
    "    \n",
    "# plot_svm takes only N, so no other keyword arguments are passed to interact\n",
    "interact(plot_svm, N=[10, 200]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The distinctive thing about SVMs is that only the support vectors matter: if you moved any of the other points without letting them cross the margin, they would have no effect on the classification results!"
   ]
  },
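  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to probe this claim is to discard everything except the support vectors and refit. The sketch below assumes the same blob data as before and uses illustrative names (`X_demo`, `y_demo`, `full`, `reduced`, `same`); the refit model should give essentially the same line, up to the solver's tolerance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_blobs\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "X_demo, y_demo = make_blobs(n_samples=50, centers=2,\n",
    "                            random_state=0, cluster_std=0.60)\n",
    "full = SVC(kernel='linear').fit(X_demo, y_demo)\n",
    "\n",
    "# Keep only the support vectors of the full fit and train again\n",
    "sv = full.support_\n",
    "reduced = SVC(kernel='linear').fit(X_demo[sv], y_demo[sv])\n",
    "\n",
    "# Compare the two models on the complete dataset\n",
    "same = np.array_equal(full.predict(X_demo), reduced.predict(X_demo))\n",
    "print('points used (all / support only):', len(X_demo), '/', len(sv))\n",
    "print('line from all points:   ', full.coef_[0], full.intercept_[0])\n",
    "print('line from support only: ', reduced.coef_[0], reduced.intercept_[0])\n",
    "print('identical predictions on the full data:', same)"
   ]
  },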
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Going further: Kernel Methods\n",
    "\n",
    "Where SVM gets incredibly exciting is when it is used in conjunction with *kernels*.\n",
    "To motivate the need for kernels, let's look at some data which is not linearly separable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_circles  # samples_generator was removed from scikit-learn\n",
    "X, y = make_circles(100, factor=.1, noise=.1)\n",
    "\n",
    "clf = SVC(kernel='linear').fit(X, y)\n",
    "\n",
    "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n",
    "plot_svc_decision_function(clf);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clearly, no linear discrimination will ever separate these data.\n",
    "One way to address this is to apply a **kernel**, which is some functional transformation of the input data.\n",
    "\n",
    "For example, one simple model we could use is a **radial basis function** centered on the origin:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = np.exp(-(X[:, 0] ** 2 + X[:, 1] ** 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we plot this along with our data, we can see the effect of it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mpl_toolkits import mplot3d\n",
    "\n",
    "def plot_3D(elev=30, azim=30):\n",
    "    ax = plt.subplot(projection='3d')\n",
    "    ax.scatter3D(X[:, 0], X[:, 1], r, c=y, s=50, cmap='spring')\n",
    "    ax.view_init(elev=elev, azim=azim)\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('y')\n",
    "    ax.set_zlabel('r')\n",
    "\n",
    "interact(plot_3D, elev=(-90, 90), azim=(-180, 180));"
   ]
  },
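  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to check whether the extra feature helps is to fit a linear `SVC` with and without it and compare training accuracy. This is only a sketch: `random_state=0` is added here for reproducibility (the cell above does not fix a seed), and `X_c`, `y_c`, `r_c`, `X_aug`, `acc_2d`, and `acc_3d` are names introduced for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_circles\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "# Fresh copy of the circles data with a fixed seed (an assumption of this sketch)\n",
    "X_c, y_c = make_circles(100, factor=.1, noise=.1, random_state=0)\n",
    "\n",
    "# Radial basis feature centered on the origin, as defined above\n",
    "r_c = np.exp(-(X_c[:, 0] ** 2 + X_c[:, 1] ** 2))\n",
    "X_aug = np.column_stack([X_c, r_c])  # columns: x, y, r\n",
    "\n",
    "# Training accuracy of a linear SVC without and with the extra feature\n",
    "acc_2d = SVC(kernel='linear').fit(X_c, y_c).score(X_c, y_c)\n",
    "acc_3d = SVC(kernel='linear').fit(X_aug, y_c).score(X_aug, y_c)\n",
    "print(f'linear SVC on (x, y):    {acc_2d:.2f}')\n",
    "print(f'linear SVC on (x, y, r): {acc_3d:.2f}')"
   ]
  },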
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that with this additional dimension, the data becomes trivially linearly separable!\n",
    "This is a relatively simple kernel; SVM has a more sophisticated version of this kernel built in. It is selected with ``kernel='rbf'``, short for *radial basis function*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = SVC(kernel='rbf')\n",
    "clf.fit(X, y)\n",
    "\n",
    "plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='spring')\n",
    "plot_svc_decision_function(clf)\n",
    "plt.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1],\n",
    "            s=200, facecolors='none');"
   ]
  },
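  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the RBF model computes, its decision function can be rebuilt by hand as a weighted sum of kernels centered on the support vectors, plus an intercept. This is a sketch under stated assumptions: `gamma` is fixed at 1.0 rather than left at the default used above, `random_state=0` is added for reproducibility, and `X_c`, `y_c`, `model`, `K`, and `manual` are names introduced here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_circles\n",
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "X_c, y_c = make_circles(100, factor=.1, noise=.1, random_state=0)\n",
    "gamma = 1.0  # set explicitly so the manual kernel matches the fitted model\n",
    "model = SVC(kernel='rbf', gamma=gamma).fit(X_c, y_c)\n",
    "\n",
    "# Kernel between every point and every support vector: shape (n_points, n_support)\n",
    "K = rbf_kernel(X_c, model.support_vectors_, gamma=gamma)\n",
    "\n",
    "# Weighted sum of the kernels plus the intercept (binary case)\n",
    "manual = K @ model.dual_coef_[0] + model.intercept_[0]\n",
    "\n",
    "print('support vectors used:', len(model.support_vectors_), 'of', len(X_c))\n",
    "print('matches decision_function:',\n",
    "      np.allclose(manual, model.decision_function(X_c), atol=1e-6))"
   ]
  },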
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here there are effectively $N$ basis functions: one centered at each point! Thanks to the \"kernel trick\", this computation proceeds efficiently: the algorithm only needs kernel evaluations between pairs of points, and never explicitly constructs the high-dimensional feature representation.\n",
    "\n",
    "We'll leave SVMs for the time being and take a look at another classification algorithm: Random Forests."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
